{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "\n",
    "# Start a virtual X server in the background and point this process at it.\n",
    "os.system('/usr/bin/Xvfb :99 -screen 0 1024x768x24 &')\n",
    "os.environ['DISPLAY'] = ':99'\n",
    "\n",
    "# pyvista reads these variables when it is imported, so they have to be exported\n",
    "# *before* `import pyvista`; setting them afterwards has no effect.\n",
    "os.environ['PYVISTA_OFF_SCREEN'] = 'True'\n",
    "os.environ['PYVISTA_USE_PANEL'] = 'True'\n",
    "\n",
    "import pyvista as pv\n",
    "import panel as pn\n",
    "\n",
    "pv.set_plot_theme(\"document\")\n",
    "pn.extension('vtk')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from plyfile import PlyData\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the parent directory importable. The membership check keeps re-runs of this\n",
    "# cell from appending the same entry again.\n",
    "if \"..\" not in sys.path:\n",
    "    sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "  <img width=\"40%\" src=\"https://raw.githubusercontent.com/nicolas-chaulet/torch-points3d/master/docs/logo.png\" />\n",
    "</p>\n",
    "\n",
    "# Registration Demo on 3DMatch\n",
    "\n",
    "In this notebook, we demonstrate point cloud registration on 3DMatch from start to finish, using a pretrained network.\n",
    "\n",
    "First, let's load some examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We read the data\n",
    "\n",
    "def read_ply(path):\n",
    "    \"\"\"Read a PLY file and return its vertex coordinates.\n",
    "\n",
    "    The x, y and z properties of the 'vertex' element are stacked and\n",
    "    transposed, so each row of the returned array holds one point.\n",
    "    \"\"\"\n",
    "    with open(path, 'rb') as ply_file:\n",
    "        vertices = PlyData.read(ply_file)['vertex']\n",
    "    coords = [vertices[axis] for axis in ('x', 'y', 'z')]\n",
    "    return np.vstack(coords).T\n",
    "path_s = \"data/3DMatch/redkitchen_000.ply\"\n",
    "path_t = \"data/3DMatch/redkitchen_010.ply\"\n",
    "\n",
    "# Read both fragments as arrays of xyz coordinates.\n",
    "pcd_s, pcd_t = (read_ply(ply_path) for ply_path in (path_s, path_t))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can wrap each point cloud in a `Batch` and apply some transformations (convert the data into sparse voxels, add a feature of ones). We can load the model too."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data preprocessing import\n",
    "from torch_points3d.core.data_transform import GridSampling3D, AddOnes, AddFeatByKey\n",
    "from torch_geometric.transforms import Compose\n",
    "from torch_geometric.data import Batch\n",
    "\n",
    "# Model\n",
    "from torch_points3d.applications.pretrained_api import PretainedRegistry\n",
    "\n",
    "# post processing\n",
    "from torch_points3d.utils.registration import get_matches, fast_global_registration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid-sample the points (size 0.02), add a \"ones\" attribute and append it to the features x.\n",
    "transform = Compose([\n",
    "    GridSampling3D(mode='last', size=0.02, quantize_coords=True),\n",
    "    AddOnes(),\n",
    "    AddFeatByKey(add_to_x=True, feat_name=\"ones\"),\n",
    "])\n",
    "\n",
    "pos_s = torch.from_numpy(pcd_s).float()\n",
    "pos_t = torch.from_numpy(pcd_t).float()\n",
    "\n",
    "# The batch vector is all zeros: every point is assigned to sample 0.\n",
    "data_s = transform(Batch(pos=pos_s, batch=torch.zeros(pcd_s.shape[0]).long()))\n",
    "data_t = transform(Batch(pos=pos_t, batch=torch.zeros(pcd_t.shape[0]).long()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This will log some errors, don't worry it's all good!\n",
    "# Fetch the pretrained registration model first, then move it to the GPU.\n",
    "model = PretainedRegistry.from_pretrained(\"minkowski-registration-3dmatch\")\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the model in evaluation mode: torch.no_grad() only turns off gradient tracking,\n",
    "# it does not stop layers such as batch norm or dropout from behaving as in training.\n",
    "model.eval()\n",
    "\n",
    "# Compute per-point features for each cloud, one forward pass per cloud.\n",
    "with torch.no_grad():\n",
    "    model.set_input(data_s, \"cuda\")\n",
    "    output_s = model.forward()\n",
    "    model.set_input(data_t, \"cuda\")\n",
    "    output_t = model.forward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have our features, let's match them. We will randomly select 5000 points from each point cloud."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of points drawn (with replacement) from each cloud for feature matching.\n",
    "NUM_MATCH_POINTS = 5000\n",
    "\n",
    "# A dedicated, seeded generator makes the random subsets identical on every\n",
    "# Restart & Run All without touching the global RNG state.\n",
    "sampler = torch.Generator().manual_seed(0)\n",
    "rand_s = torch.randint(0, len(output_s), (NUM_MATCH_POINTS, ), generator=sampler)\n",
    "rand_t = torch.randint(0, len(output_t), (NUM_MATCH_POINTS, ), generator=sampler)\n",
    "\n",
    "# Column 0 of `matches` indexes the sampled source points, column 1 the sampled target points.\n",
    "matches = get_matches(output_s[rand_s], output_t[rand_t])\n",
    "\n",
    "# Estimate the transformation from the matched point pairs.\n",
    "T_est = fast_global_registration(data_s.pos[rand_s][matches[:, 0]], data_t.pos[rand_t][matches[:, 1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show both clouds as loaded, before applying any estimated transformation.\n",
    "pl1 = pv.Plotter(notebook=True)\n",
    "for cloud, rgb in ((data_s, [0.9, 0.7, 0.1]), (data_t, [0.1, 0.7, 0.9])):\n",
    "    pl1.add_points(cloud.pos.numpy(), color=rgb)\n",
    "pl1.enable_eye_dome_lighting()\n",
    "pan1 = pn.panel(\n",
    "    pl1.ren_win,\n",
    "    sizing_mode='scale_both',\n",
    "    aspect_ratio=1,\n",
    "    orientation_widget=True,\n",
    ")\n",
    "pan1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pl1 = pv.Plotter(notebook=True)\n",
    "\n",
    "# Apply T_est to the source positions: multiply by its upper-left 3x3 block,\n",
    "# then add the first three entries of its last column.\n",
    "linear, offset = T_est[:3, :3], T_est[:3, 3]\n",
    "transformed = data_s.pos @ linear.T + offset\n",
    "\n",
    "pl1.add_points(transformed.cpu().numpy(), color=[0.9, 0.7, 0.1])\n",
    "pl1.add_points(data_t.pos.cpu().numpy(), color=[0.1, 0.7, 0.9])\n",
    "pan1 = pn.panel(\n",
    "    pl1.ren_win,\n",
    "    sizing_mode='scale_both',\n",
    "    aspect_ratio=1,\n",
    "    orientation_widget=True,\n",
    ")\n",
    "pan1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization\n",
    "Let's visualize the features to see what the network has learnt.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "def compute_color_from_features(list_feat):\n",
    "    \"\"\"Turn feature arrays into 3-component colours through a shared PCA.\n",
    "\n",
    "    All arrays in ``list_feat`` are stacked to fit a single 3-component PCA, so\n",
    "    the same projection and the same min/max scaling are used for every array.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    list_feat : list of 2D arrays\n",
    "        Feature arrays sharing the same number of columns.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of ndarray\n",
    "        One array with 3 columns per input array, each component scaled to\n",
    "        [0, 1] over the stacked data.\n",
    "    \"\"\"\n",
    "    feats = np.vstack(list_feat)\n",
    "    pca = PCA(n_components=3)\n",
    "    pca.fit(feats)\n",
    "    # Project once and reuse the result for the min, the max and the colours.\n",
    "    projected = pca.transform(feats)\n",
    "    min_col = projected.min(axis=0)\n",
    "    span = projected.max(axis=0) - min_col\n",
    "    # A constant component would give 0 / 0 = NaN; scale it by 1 so it maps to 0.\n",
    "    span[span == 0] = 1.0\n",
    "    scaled = (projected - min_col) / span\n",
    "    # Cut the stacked colours back into one chunk per input array.\n",
    "    row_counts = [np.atleast_2d(feat).shape[0] for feat in list_feat]\n",
    "    return np.split(scaled, np.cumsum(row_counts)[:-1])\n",
    "# Bring both feature tensors to NumPy on the CPU before colouring them together.\n",
    "feats_s = output_s.detach().cpu().numpy()\n",
    "feats_t = output_t.detach().cpu().numpy()\n",
    "list_color = compute_color_from_features([feats_s, feats_t])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two plotters shown side by side: the transformed source cloud and the target cloud,\n",
    "# each carrying its entry of list_color as the \"color\" array.\n",
    "panel_options = dict(sizing_mode='scale_both', aspect_ratio=1, orientation_widget=True)\n",
    "\n",
    "pl1 = pv.Plotter(notebook=True)\n",
    "point_cloud = pv.PolyData(transformed.cpu().numpy())\n",
    "point_cloud[\"color\"] = list_color[0]\n",
    "pl1.add_points(point_cloud)\n",
    "\n",
    "pl2 = pv.Plotter(notebook=True)\n",
    "point_cloud = pv.PolyData(data_t.pos.cpu().numpy())\n",
    "point_cloud[\"color\"] = list_color[1]\n",
    "pl2.add_points(point_cloud)\n",
    "\n",
    "pan1 = pn.panel(pl1.ren_win, **panel_options)\n",
    "pan2 = pn.panel(pl2.ren_win, **panel_options)\n",
    "pn.Row(pan1, pan2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
